{
    "cells": [
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "import torchvision.transforms as T\n",
                "import matplotlib.pyplot as plt\n",
                "from PIL import Image\n",
                "\n",
                "# from model.networks_tf import Generator\n",
                "from model.networks import Generator\n",
                "\n",
                "plt.rcParams['figure.facecolor'] = 'white'"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Load generator model\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Opt-in switch: the model stays on the CPU unless this is set to True\n",
                "# and a CUDA device is actually available.\n",
                "use_cuda_if_available = False\n",
                "if torch.cuda.is_available() and use_cuda_if_available:\n",
                "    device = torch.device('cuda')\n",
                "else:\n",
                "    device = torch.device('cpu')\n",
                "\n",
                "# Checkpoint handed to the generator. The commented-out path presumably pairs\n",
                "# with the commented-out `model.networks_tf` import - confirm before switching.\n",
                "# sd_path = 'pretrained/states_tf_places2.pth'\n",
                "sd_path = 'pretrained/states_pt_places2.pth'\n",
                "\n",
                "# Build the generator from the checkpoint, then move it to the chosen device.\n",
                "generator = Generator(checkpoint=sd_path, return_flow=True)\n",
                "generator = generator.to(device)"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Load image and mask\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Number of the example case to load; it selects both the image and its mask file.\n",
                "n = 1\n",
                "image_path = f\"examples/inpaint/case{n}.png\"\n",
                "mask_path = f\"examples/inpaint/case{n}_mask.png\"\n",
                "\n",
                "# Open both files as PIL images; later cells plot them and turn them into tensors.\n",
                "image_pil, mask_pil = Image.open(image_path), Image.open(mask_path)"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Plot raw image and mask\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Side-by-side preview: raw image on the left, mask on the right.\n",
                "_, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(15, 10))\n",
                "\n",
                "# The mask is shown as a single-channel ('L') image.\n",
                "mask_gray = mask_pil.convert('L')\n",
                "\n",
                "ax1.imshow(image_pil)\n",
                "ax2.imshow(mask_gray)\n",
                "plt.show()"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Inpaint"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# One converter instance serves both inputs.\n",
                "to_tensor = T.ToTensor()\n",
                "image = to_tensor(image_pil).to(device)\n",
                "mask = to_tensor(mask_pil).to(device)\n",
                "\n",
                "# `return_vals` names the results requested from the generator; the cells\n",
                "# below index `output` in this same order.\n",
                "output = generator.infer(\n",
                "    image,\n",
                "    mask,\n",
                "    return_vals=['inpainted', 'stage1', 'stage2', 'flow'],\n",
                ")"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Results\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Binarize the first mask channel into a NEW name: the raw `mask` tensor from\n",
                "# the inference cell is left untouched, so cells can be re-run in any order\n",
                "# without the mask silently changing shape or values.\n",
                "mask_binary = (mask[0:1] > 0.).to(dtype=torch.float32)\n",
                "\n",
                "# Zero out the masked pixels and move channels last for plotting.\n",
                "image_masked = (image*(1-mask_binary)).permute(1, 2, 0).cpu()\n",
                "\n",
                "print(\"Result:\")\n",
                "plt.figure(figsize=(10, 8))\n",
                "plt.imshow(output[0])\n",
                "plt.show()\n",
                "\n",
                "# Scale the comparison figure so its aspect ratio follows the input image.\n",
                "w, h = image_pil.size\n",
                "\n",
                "print(\"\"\"| Raw | Masked |\n",
                "| Stage1 | Stage2 |\"\"\")\n",
                "_, axes = plt.subplots(2, 2, figsize=(15*w / max(w,h), 15*h / max(w,h)))\n",
                "\n",
                "# Row-major order matches the legend printed above; every panel gets a title\n",
                "# so the figure can be read on its own.\n",
                "panels = [('Raw', image_pil), ('Masked', image_masked),\n",
                "          ('Stage1', output[1]), ('Stage2', output[2])]\n",
                "for ax, (title, panel) in zip(axes.flat, panels):\n",
                "    ax.imshow(panel)\n",
                "    ax.set_title(title)\n",
                "plt.show()"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Plot attention flow map\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# output[3] is presumably the 'flow' entry requested via `return_vals` in the\n",
                "# inference cell - confirm against Generator.infer. Take the first item, move\n",
                "# it to the CPU and put channels last for imshow.\n",
                "flow_map = output[3][0].cpu().permute(1, 2, 0)\n",
                "\n",
                "plt.imshow(flow_map)\n",
                "plt.title('Attention flow map')\n",
                "# End with plt.show() like the other plotting cells, so the AxesImage repr\n",
                "# is not echoed as the cell output.\n",
                "plt.show()"
            ]
        },
        {
            "attachments": {},
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "## Test Contextual Attention\n"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from model.networks import ContextualAttention\n",
                "# from model.networks_tf import ContextualAttention\n",
                "\n",
                "\n",
                "# Single source of truth for the layer geometry: the same values configure the\n",
                "# layer and decide how the inputs are cropped, so the two cannot drift apart.\n",
                "CA_RATE = 2\n",
                "CA_STRIDE = 1\n",
                "\n",
                "contextual_attention = ContextualAttention(ksize=3, stride=CA_STRIDE, rate=CA_RATE,\n",
                "                                           fuse_k=3, softmax_scale=10,\n",
                "                                           fuse=False,\n",
                "                                           return_flow=True)\n",
                "\n",
                "imageB = 'examples/style_transfer/bike.jpg'\n",
                "imageA = 'examples/style_transfer/bnw_butterfly.png'\n",
                "\n",
                "\n",
                "def _load_rgb_tensor(path, grid, downscale=1):\n",
                "    \"\"\"Load `path` as a (1, 3, H, W) float tensor with H and W cropped to multiples of `grid`.\n",
                "\n",
                "    The image is forced to RGB so that grayscale, palette or RGBA files still\n",
                "    yield exactly three channels. A `downscale` other than 1 shrinks both sides\n",
                "    by that integer factor (bicubic) before cropping.\n",
                "    \"\"\"\n",
                "    # convert() returns a loaded copy, so the file can be closed right away.\n",
                "    with Image.open(path) as source:\n",
                "        picture = source.convert('RGB')\n",
                "    if downscale != 1:\n",
                "        picture = picture.resize((picture.width//downscale, picture.height//downscale),\n",
                "                                 resample=Image.BICUBIC)\n",
                "    tensor = T.ToTensor()(picture)\n",
                "    _, h, w = tensor.shape\n",
                "    return tensor[:, :h//grid*grid, :w//grid*grid].unsqueeze(0)\n",
                "\n",
                "\n",
                "def test_contextual_attention(imageA, imageB):\n",
                "    \"\"\"Run the contextual attention layer on two 3-channel images\n",
                "    (instead of n-channel feature maps).\n",
                "\n",
                "    Args:\n",
                "        imageA: path of the image passed as the layer's second input; it is\n",
                "            downscaled by a factor of 2 first.\n",
                "        imageB: path of the image passed as the layer's first input.\n",
                "\n",
                "    Returns:\n",
                "        The `(yt, flow)` pair returned by `contextual_attention`.\n",
                "    \"\"\"\n",
                "    grid = CA_RATE*CA_STRIDE\n",
                "\n",
                "    b = _load_rgb_tensor(imageA, grid, downscale=2)\n",
                "    print(f\"Size of imageA: {b.shape}\")\n",
                "\n",
                "    f = _load_rgb_tensor(imageB, grid)\n",
                "    print(f\"Size of imageB: {f.shape}\")\n",
                "\n",
                "    # Inference only: no autograd graph is needed, and the outputs stay plottable.\n",
                "    # Inputs are rescaled from [0, 1] to [0, 255] before the call.\n",
                "    with torch.no_grad():\n",
                "        yt, flow = contextual_attention(f*255., b*255.)\n",
                "\n",
                "    return yt, flow\n",
                "\n",
                "\n",
                "yt, flow = test_contextual_attention(imageA, imageB)\n",
                "\n",
                "_, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 10))\n",
                "ax1.imshow(yt[0].permute(1, 2, 0)/255.)\n",
                "ax1.set_title('Contextual attention output')\n",
                "ax2.imshow(flow[0].permute(1, 2, 0))\n",
                "ax2.set_title('Flow map')\n",
                "plt.show()"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Python 3.9.12 ('base')",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.9.12"
        },
        "orig_nbformat": 4,
        "vscode": {
            "interpreter": {
                "hash": "f2e710066212467aef4080e32cd195a064f9d11c337be8995a8a23c3dcd42ac2"
            }
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
